# 0. 引入必要的包
import os
import numpy as np
from sklearn.model_selection import train_test_split
from util import get, preprocess_image, dump
from tqdm import tqdm

# 1. Read settings from the configuration file
train_dir = get("train")  # path to the training data
char_styles = get("char_styles")  # list of calligraphy styles; must be a list
new_size = get("new_size")  # target image size tuple; must contain h and w
Xy_dir = get('Xy_root')  # root used to build the dump target at the end of the script

# 2. Build X, y
print("# 读取训练数据并进行预处理，")
X = []
y = []
y_labels = {'篆书': 0, '隶书': 1, '草书': 2, '行书': 3, '楷书': 4}

# Map each configured style to the list of image paths that belong to it.
# Seeding the dict from char_styles keeps the configured key order and gives
# styles without any matching file an empty list.
data_dict = {style: [] for style in char_styles}

# Walk the training directory a single time and group the files by style,
# instead of re-scanning the whole tree once per style.
# Accepted file names look like ``train_<style>_<rest>.jpg``: exactly three
# underscore-separated parts, the middle one being the style.
for root, _dirs, files in os.walk(train_dir):
    for file_name in files:
        if not (file_name.startswith('train_') and file_name.endswith('.jpg')):
            continue
        parts = file_name.split('_')
        if len(parts) != 3:
            continue
        style_paths = data_dict.get(parts[1])
        # Files whose style is not listed in char_styles are ignored.
        if style_paths is not None:
            style_paths.append(os.path.join(root, file_name))

data = []  # retained: module-level name left empty by the original script


def process_image(names, paths):
    """Preprocess the images of one style and append the results to X and y.

    Args:
        names: Style name; used as the key into ``y_labels`` and shown in the
            progress-bar caption.
        paths: Iterable of image file paths for that style.

    Every path is handed to ``preprocess_image`` together with ``new_size``.
    On success the returned image is appended to the module-level list ``X``
    and the style's label to ``y``. If ``preprocess_image`` raises
    ``ValueError`` the error is printed and that image is skipped.
    """
    progress = tqdm(paths, ascii=True, desc=f'处理 {names} 图像')
    for image_path in progress:
        try:
            processed = preprocess_image(image_path, new_size)
        except ValueError as err:
            print(err)
        else:
            X.append(processed)
            y.append(y_labels[names])


# Preprocess the images of every style; process_image fills X and y in place.
for style_name, style_paths in data_dict.items():
    process_image(style_name, style_paths)

X = np.array(X, dtype=np.float64)
y = np.array(y, dtype=np.int64)

# 3. Split into training and test sets
print("# 将数据按 80% 和 20% 的比例分割")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, train_size=0.8)

# 4. Print the shape and dtype of each split
for caption, split in (
    ("X_train: ", X_train),  # training features
    ("X_test: ", X_test),  # test features
    ("y_train: ", y_train),  # training labels
    ("y_test: ", y_test),  # test labels
):
    print(caption, split.shape, split.dtype)

# 5. Serialize the split training and test samples
# os.path.join replaces the former ``Xy_dir + '/Xy'`` concatenation: it does not
# produce the root path '/Xy' when Xy_dir is empty, does not double the
# separator when Xy_dir already ends with one, and accepts path-like objects.
dump((X_train, X_test, y_train, y_test), "序列化分割后的训练和测试样本", os.path.join(Xy_dir, 'Xy'))
